{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Point cloud classification with PointNet\n",
    "\n",
    "**Author:** [David Griffiths](https://dgriffiths3.github.io)<br>\n",
    "**Date created:** 2020/05/25<br>\n",
    "**Last modified:** 2020/05/26<br>\n",
    "**Description:** Implementation of PointNet for ModelNet10 classification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "Classification, detection and segmentation of unordered 3D point sets, i.e. point clouds,\n",
    "are core problems in computer vision. This example implements the seminal point cloud\n",
    "deep learning paper [PointNet (Qi et al., 2017)](https://arxiv.org/abs/1612.00593). For a\n",
    "detailed introduction to PointNet see [this blog\n",
    "post](https://medium.com/@luis_gonzales/an-in-depth-look-at-pointnet-111d7efdaa1a).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n",
    "\n",
    "If using Colab, first install trimesh with `!pip install trimesh`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "import os\n",
    "import glob\n",
    "import trimesh\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "tf.random.set_seed(1234)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Load dataset\n",
    "\n",
    "We use the ModelNet10 dataset, the smaller 10-class version of the ModelNet40\n",
    "dataset. First download the data:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "DATA_DIR = tf.keras.utils.get_file(\n",
    "    \"modelnet.zip\",\n",
    "    \"http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip\",\n",
    "    extract=True,\n",
    ")\n",
    "DATA_DIR = os.path.join(os.path.dirname(DATA_DIR), \"ModelNet10\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can use the `trimesh` package to read and visualize the `.off` mesh files.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "mesh = trimesh.load(os.path.join(DATA_DIR, \"chair/train/chair_0001.off\"))\n",
    "mesh.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To convert a mesh file to a point cloud we first need to sample points on the mesh\n",
    "surface. `.sample()` performs a uniform random sampling. Here we sample at 2048 locations\n",
    "and visualize the result in `matplotlib`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "points = mesh.sample(2048)\n",
    "\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(111, projection=\"3d\")\n",
    "ax.scatter(points[:, 0], points[:, 1], points[:, 2])\n",
    "ax.set_axis_off()\n",
    "plt.show()\n"
   ]
  },
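  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a rough illustration of what uniform surface sampling involves, the sketch below picks\n",
    "triangles with probability proportional to their area and then draws a uniformly\n",
    "distributed point inside each chosen triangle using barycentric coordinates. This is a\n",
    "minimal NumPy sketch, not necessarily how `trimesh` implements `.sample()`; the helper name\n",
    "`sample_surface_sketch` and the toy unit-square mesh are assumptions made for this example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_surface_sketch(vertices, faces, count, seed=0):\n",
    "    # Hypothetical helper: area-weighted random sampling on a triangle mesh\n",
    "    rng = np.random.default_rng(seed)\n",
    "    tri = vertices[faces]  # (num_faces, 3, 3): the three corners of each face\n",
    "    # Triangle area is half the norm of the cross product of two edges\n",
    "    areas = 0.5 * np.linalg.norm(\n",
    "        np.cross(tri[:, 1] - tri[:, 0], tri[:, 2] - tri[:, 0]), axis=1\n",
    "    )\n",
    "    # Larger triangles are chosen more often, so density is uniform over the surface\n",
    "    face_idx = rng.choice(len(faces), size=count, p=areas / areas.sum())\n",
    "    # Uniform barycentric coordinates via the square-root trick\n",
    "    r1 = np.sqrt(rng.random(count))[:, None]\n",
    "    r2 = rng.random(count)[:, None]\n",
    "    a, b, c = tri[face_idx, 0], tri[face_idx, 1], tri[face_idx, 2]\n",
    "    return (1 - r1) * a + r1 * (1 - r2) * b + r1 * r2 * c\n",
    "\n",
    "\n",
    "# Toy mesh (an assumption for this demo): a unit square made of two triangles\n",
    "square_vertices = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=float)\n",
    "square_faces = np.array([[0, 1, 2], [0, 2, 3]])\n",
    "sketch_points = sample_surface_sketch(square_vertices, square_faces, 2048)\n",
    "print(sketch_points.shape)\n"
   ]
  },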
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To generate a `tf.data.Dataset()` we first need to parse through the ModelNet data\n",
    "folders. Each mesh is loaded and sampled into a point cloud before being added to a\n",
    "standard Python list and converted to a `numpy` array. We also store the current\n",
    "enumerate index value as the object label and use a dictionary to recall this later.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def parse_dataset(num_points=2048):\n",
    "\n",
    "    train_points = []\n",
    "    train_labels = []\n",
    "    test_points = []\n",
    "    test_labels = []\n",
    "    class_map = {}\n",
    "    # Keep only the class directories (skips README.txt); sort for stable class IDs\n",
    "    folders = sorted(\n",
    "        f for f in glob.glob(os.path.join(DATA_DIR, \"*\")) if os.path.isdir(f)\n",
    "    )\n",
    "\n",
    "    for i, folder in enumerate(folders):\n",
    "        print(\"processing class: {}\".format(os.path.basename(folder)))\n",
    "        # store folder name with ID so we can retrieve later\n",
    "        class_map[i] = os.path.basename(folder)\n",
    "        # gather all files\n",
    "        train_files = glob.glob(os.path.join(folder, \"train/*\"))\n",
    "        test_files = glob.glob(os.path.join(folder, \"test/*\"))\n",
    "\n",
    "        for f in train_files:\n",
    "            train_points.append(trimesh.load(f).sample(num_points))\n",
    "            train_labels.append(i)\n",
    "\n",
    "        for f in test_files:\n",
    "            test_points.append(trimesh.load(f).sample(num_points))\n",
    "            test_labels.append(i)\n",
    "\n",
    "    return (\n",
    "        np.array(train_points),\n",
    "        np.array(test_points),\n",
    "        np.array(train_labels),\n",
    "        np.array(test_labels),\n",
    "        class_map,\n",
    "    )\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Set the number of points to sample and the batch size, then parse the dataset. This can\n",
    "take about 5 minutes to complete.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "NUM_POINTS = 2048\n",
    "NUM_CLASSES = 10\n",
    "BATCH_SIZE = 32\n",
    "\n",
    "train_points, test_points, train_labels, test_labels, CLASS_MAP = parse_dataset(\n",
    "    NUM_POINTS\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Our data can now be read into a `tf.data.Dataset()` object. We set the shuffle buffer\n",
    "size to the entire size of the dataset because, prior to this, the data is ordered by class.\n",
    "Data augmentation is important when working with point cloud data. We create an\n",
    "augmentation function to jitter and shuffle the train dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def augment(points, label):\n",
    "    # jitter points\n",
    "    points += tf.random.uniform(points.shape, -0.005, 0.005, dtype=tf.float64)\n",
    "    # shuffle points\n",
    "    points = tf.random.shuffle(points)\n",
    "    return points, label\n",
    "\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices((train_points, train_labels))\n",
    "test_dataset = tf.data.Dataset.from_tensor_slices((test_points, test_labels))\n",
    "\n",
    "train_dataset = train_dataset.shuffle(len(train_points)).map(augment).batch(BATCH_SIZE)\n",
    "test_dataset = test_dataset.shuffle(len(test_points)).batch(BATCH_SIZE)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Build a model\n",
    "\n",
    "Each convolution and fully-connected layer (with the exception of the end layers) consists of\n",
    "Convolution / Dense -> Batch Normalization -> ReLU Activation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def conv_bn(x, filters):\n",
    "    x = layers.Conv1D(filters, kernel_size=1, padding=\"valid\")(x)\n",
    "    x = layers.BatchNormalization(momentum=0.0)(x)\n",
    "    return layers.Activation(\"relu\")(x)\n",
    "\n",
    "\n",
    "def dense_bn(x, filters):\n",
    "    x = layers.Dense(filters)(x)\n",
    "    x = layers.BatchNormalization(momentum=0.0)(x)\n",
    "    return layers.Activation(\"relu\")(x)\n",
    "\n"
   ]
  },
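  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "A `Conv1D` layer with `kernel_size=1` applies the same dense transform to every point\n",
    "independently, which is how the per-point MLP weights are shared across the cloud. The\n",
    "NumPy sketch below illustrates that equivalence with made-up shapes and random weights;\n",
    "all `demo_*` names are placeholders for this example and are not part of the model.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "demo_rng = np.random.default_rng(0)\n",
    "demo_cloud = demo_rng.normal(size=(5, 3))  # 5 points with xyz coordinates\n",
    "demo_kernel = demo_rng.normal(size=(3, 4))  # maps 3 input features to 4 filters\n",
    "demo_bias = demo_rng.normal(size=4)\n",
    "\n",
    "# Kernel-size-1 convolution: one matrix product over the whole cloud\n",
    "demo_conv_out = demo_cloud @ demo_kernel + demo_bias\n",
    "\n",
    "# The same weights applied to each point separately, as a shared dense layer\n",
    "demo_dense_out = np.stack([p @ demo_kernel + demo_bias for p in demo_cloud])\n",
    "print(np.allclose(demo_conv_out, demo_dense_out))\n",
    "\n",
    "# Every point is processed independently, so reordering the points\n",
    "# simply reorders the rows of the output\n",
    "demo_perm = demo_rng.permutation(len(demo_cloud))\n",
    "demo_perm_out = demo_cloud[demo_perm] @ demo_kernel + demo_bias\n",
    "print(np.allclose(demo_conv_out[demo_perm], demo_perm_out))\n"
   ]
  },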
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "PointNet consists of two core components: the primary MLP network and the transformer\n",
    "net (T-net). The T-net aims to learn an affine transformation matrix with its own mini\n",
    "network. The T-net is used twice. The first time, it transforms the input features (n, 3)\n",
    "into a canonical representation. The second time, it applies an affine transformation for\n",
    "alignment in feature space (n, 32). As per the original paper we constrain the\n",
    "transformation to be close to an orthogonal matrix (i.e. ||X*X^T - I|| = 0).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "class OrthogonalRegularizer(keras.regularizers.Regularizer):\n",
    "    def __init__(self, num_features, l2reg=0.001):\n",
    "        self.num_features = num_features\n",
    "        self.l2reg = l2reg\n",
    "        self.eye = tf.eye(num_features)\n",
    "\n",
    "    def __call__(self, x):\n",
    "        x = tf.reshape(x, (-1, self.num_features, self.num_features))\n",
    "        # Batched product X @ X^T: one (num_features, num_features) matrix per sample\n",
    "        xxt = tf.matmul(x, x, transpose_b=True)\n",
    "        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))\n",
    "\n"
   ]
  },
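  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "To make the orthogonality penalty concrete, the NumPy sketch below evaluates\n",
    "`l2reg * sum((A @ A.T - I) ** 2)` for a single matrix. A rotation matrix is orthogonal and\n",
    "gives a penalty of numerically zero, while a uniformly scaled identity does not. The\n",
    "function name `orthogonality_penalty` and the two test matrices are assumptions chosen\n",
    "for illustration only.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def orthogonality_penalty(matrix, l2reg=0.001):\n",
    "    # Hypothetical helper: l2reg times the squared Frobenius norm of (A A^T - I)\n",
    "    identity = np.eye(matrix.shape[0])\n",
    "    return l2reg * np.sum(np.square(matrix @ matrix.T - identity))\n",
    "\n",
    "\n",
    "demo_theta = np.pi / 6\n",
    "demo_rotation = np.array(\n",
    "    [\n",
    "        [np.cos(demo_theta), -np.sin(demo_theta), 0.0],\n",
    "        [np.sin(demo_theta), np.cos(demo_theta), 0.0],\n",
    "        [0.0, 0.0, 1.0],\n",
    "    ]\n",
    ")\n",
    "demo_scaled = 2.0 * np.eye(3)\n",
    "\n",
    "# Rotations are orthogonal, so the penalty is zero up to rounding error\n",
    "print(orthogonality_penalty(demo_rotation))\n",
    "# A A^T = 4I here, so the penalty is 0.001 * 3 * (4 - 1) ** 2 = 0.027\n",
    "print(orthogonality_penalty(demo_scaled))\n"
   ]
  },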
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "We can then define a general function to build T-net layers.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def tnet(inputs, num_features):\n",
    "\n",
    "    # Initialise bias as the identity matrix\n",
    "    bias = keras.initializers.Constant(np.eye(num_features).flatten())\n",
    "    reg = OrthogonalRegularizer(num_features)\n",
    "\n",
    "    x = conv_bn(inputs, 32)\n",
    "    x = conv_bn(x, 64)\n",
    "    x = conv_bn(x, 512)\n",
    "    x = layers.GlobalMaxPooling1D()(x)\n",
    "    x = dense_bn(x, 256)\n",
    "    x = dense_bn(x, 128)\n",
    "    x = layers.Dense(\n",
    "        num_features * num_features,\n",
    "        kernel_initializer=\"zeros\",\n",
    "        bias_initializer=bias,\n",
    "        activity_regularizer=reg,\n",
    "    )(x)\n",
    "    feat_T = layers.Reshape((num_features, num_features))(x)\n",
    "    # Apply affine transformation to input features\n",
    "    return layers.Dot(axes=(2, 1))([inputs, feat_T])\n",
    "\n"
   ]
  },
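  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "With a zero kernel and an identity bias, the final `Dense` layer of the T-net initially\n",
    "outputs the flattened identity matrix for any input, so the learned transformation starts\n",
    "out as a no-op. The NumPy sketch below mimics only that last layer and the matrix product,\n",
    "using assumed shapes and placeholder `tnet_demo_*` names; it does not call the Keras layers.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "tnet_demo_n = 3\n",
    "tnet_demo_rng = np.random.default_rng(0)\n",
    "tnet_demo_cloud = tnet_demo_rng.normal(size=(6, tnet_demo_n))  # 6 points\n",
    "tnet_demo_hidden = tnet_demo_rng.normal(size=128)  # stand-in for the 128-d feature\n",
    "\n",
    "# Final dense layer at initialisation: zero kernel, flattened identity as bias\n",
    "tnet_demo_kernel = np.zeros((128, tnet_demo_n * tnet_demo_n))\n",
    "tnet_demo_bias = np.eye(tnet_demo_n).flatten()\n",
    "tnet_demo_flat = tnet_demo_hidden @ tnet_demo_kernel + tnet_demo_bias\n",
    "tnet_demo_matrix = tnet_demo_flat.reshape(tnet_demo_n, tnet_demo_n)\n",
    "\n",
    "# Multiplying the points by the predicted matrix leaves them unchanged\n",
    "tnet_demo_out = tnet_demo_cloud @ tnet_demo_matrix\n",
    "print(np.allclose(tnet_demo_out, tnet_demo_cloud))\n"
   ]
  },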
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The main network can then be implemented in the same manner, where the T-net mini models\n",
    "can be dropped in as layers in the graph. Here we replicate the network architecture\n",
    "published in the original paper but with half the number of filters or units at each layer,\n",
    "as we are using the smaller 10-class ModelNet dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "inputs = keras.Input(shape=(NUM_POINTS, 3))\n",
    "\n",
    "x = tnet(inputs, 3)\n",
    "x = conv_bn(x, 32)\n",
    "x = conv_bn(x, 32)\n",
    "x = tnet(x, 32)\n",
    "x = conv_bn(x, 32)\n",
    "x = conv_bn(x, 64)\n",
    "x = conv_bn(x, 512)\n",
    "x = layers.GlobalMaxPooling1D()(x)\n",
    "x = dense_bn(x, 256)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "x = dense_bn(x, 128)\n",
    "x = layers.Dropout(0.3)(x)\n",
    "\n",
    "outputs = layers.Dense(NUM_CLASSES, activation=\"softmax\")(x)\n",
    "\n",
    "model = keras.Model(inputs=inputs, outputs=outputs, name=\"pointnet\")\n",
    "model.summary()\n"
   ]
  },
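  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The global max pooling layer is what makes the network insensitive to point order: it is\n",
    "a symmetric function, so any permutation of the points yields the same global feature.\n",
    "The following is a minimal NumPy sketch with made-up per-point features; the `pool_*`\n",
    "names and the random values are assumptions for illustration.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "pool_rng = np.random.default_rng(0)\n",
    "pool_features = pool_rng.normal(size=(2048, 512))  # per-point features (assumed shape)\n",
    "\n",
    "# Global max pooling: take the maximum of each feature channel over all points\n",
    "pool_global = pool_features.max(axis=0)\n",
    "\n",
    "# Shuffle the point order and pool again\n",
    "pool_shuffled = pool_features[pool_rng.permutation(len(pool_features))]\n",
    "pool_global_shuffled = pool_shuffled.max(axis=0)\n",
    "\n",
    "print(pool_global.shape)\n",
    "print(np.array_equal(pool_global, pool_global_shuffled))\n"
   ]
  },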
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Train model\n",
    "\n",
    "Once the model is defined it can be trained like any other standard classification model\n",
    "using `.compile()` and `.fit()`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    loss=\"sparse_categorical_crossentropy\",\n",
    "    optimizer=keras.optimizers.Adam(learning_rate=0.001),\n",
    "    metrics=[\"sparse_categorical_accuracy\"],\n",
    ")\n",
    "\n",
    "model.fit(train_dataset, epochs=20, validation_data=test_dataset)\n"
   ]
  },
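  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "For reference, sparse categorical cross-entropy takes integer class labels and averages\n",
    "`-log(p[label])` over the batch. The NumPy sketch below computes the loss and accuracy by\n",
    "hand for made-up softmax outputs with three classes (the model above has ten); the\n",
    "`loss_demo_*` names and values are assumptions for this example.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "loss_demo_probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.8, 0.1]])  # softmax outputs\n",
    "loss_demo_labels = np.array([0, 1])  # integer class IDs, not one-hot vectors\n",
    "\n",
    "# Pick the predicted probability of the true class for each sample\n",
    "loss_demo_rows = np.arange(len(loss_demo_labels))\n",
    "loss_demo_true = loss_demo_probs[loss_demo_rows, loss_demo_labels]\n",
    "loss_demo_value = -np.mean(np.log(loss_demo_true))\n",
    "\n",
    "# Accuracy compares the arg-max class with the integer label\n",
    "loss_demo_accuracy = np.mean(loss_demo_probs.argmax(axis=1) == loss_demo_labels)\n",
    "\n",
    "print(loss_demo_value, loss_demo_accuracy)\n"
   ]
  },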
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Visualize predictions\n",
    "\n",
    "We can use matplotlib to visualize our trained model performance.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "data = test_dataset.take(1)\n",
    "\n",
    "points, labels = list(data)[0]\n",
    "points = points[:8, ...]\n",
    "labels = labels[:8, ...]\n",
    "\n",
    "# run test data through model\n",
    "preds = model.predict(points)\n",
    "preds = tf.math.argmax(preds, -1)\n",
    "\n",
    "points = points.numpy()\n",
    "\n",
    "# plot points with predicted class and label\n",
    "fig = plt.figure(figsize=(15, 10))\n",
    "for i in range(8):\n",
    "    ax = fig.add_subplot(2, 4, i + 1, projection=\"3d\")\n",
    "    ax.scatter(points[i, :, 0], points[i, :, 1], points[i, :, 2])\n",
    "    ax.set_title(\n",
    "        \"pred: {:}, label: {:}\".format(\n",
    "            CLASS_MAP[preds[i].numpy()], CLASS_MAP[labels.numpy()[i]]\n",
    "        )\n",
    "    )\n",
    "    ax.set_axis_off()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "pointnet",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}